{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX HardSwish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Repository root, taken as the sixth ancestor of the current working directory.\n",
    "# NOTE(review): this hard-codes how deep the notebook sits in the tree (and raises\n",
    "# IndexError from a shallower directory) -- confirm before moving the notebook.\n",
    "ROOT = Path(\".\").resolve().parents[5]\n",
    "\n",
    "# Put <ROOT>/tests on the import path before importing `tools`\n",
    "# (presumably that is where `tools` lives -- verify).\n",
    "# The membership check keeps a re-run of this cell from appending a duplicate entry.\n",
    "tests_dir = f\"{ROOT}/tests\"\n",
    "if tests_dir not in sys.path:\n",
    "    sys.path.append(tests_dir)\n",
    "\n",
    "import tools\n",
    "from d2py.utils.file import mkdir\n",
    "\n",
    "# Directory that the export cell writes the generated .onnx file into.\n",
    "root_dir = \".temp\"\n",
    "mkdir(root_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exported graph: graph(%data : Float(1, 1000, strides=[1000, 1], requires_grad=0, device=cpu)):\n",
      "  %/Constant_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={3}, onnx_name=\"/Constant\"](), scope: __main__.M:: # /tmp/ipykernel_444618/1751548276.py:9:0\n",
      "  %/Add_output_0 : Float(1, 1000, strides=[1000, 1], requires_grad=0, device=cpu) = onnx::Add[onnx_name=\"/Add\"](%data, %/Constant_output_0), scope: __main__.M:: # /tmp/ipykernel_444618/1751548276.py:9:0\n",
      "  %/Constant_1_output_0 : Float(device=cpu) = onnx::Constant[value={0}, onnx_name=\"/Constant_1\"](), scope: __main__.M:: # /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/functional.py:1551:0\n",
      "  %/Constant_2_output_0 : Float(device=cpu) = onnx::Constant[value={6}, onnx_name=\"/Constant_2\"](), scope: __main__.M:: # /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/functional.py:1551:0\n",
      "  %/Clip_output_0 : Float(1, 1000, strides=[1000, 1], requires_grad=0, device=cpu) = onnx::Clip[onnx_name=\"/Clip\"](%/Add_output_0, %/Constant_1_output_0, %/Constant_2_output_0), scope: __main__.M:: # /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/functional.py:1551:0\n",
      "  %/Constant_3_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={6}, onnx_name=\"/Constant_3\"](), scope: __main__.M:: # /tmp/ipykernel_444618/1751548276.py:9:0\n",
      "  %/Div_output_0 : Float(1, 1000, strides=[1000, 1], requires_grad=0, device=cpu) = onnx::Div[onnx_name=\"/Div\"](%/Clip_output_0, %/Constant_3_output_0), scope: __main__.M:: # /tmp/ipykernel_444618/1751548276.py:9:0\n",
      "  %output : Float(1, 1000, strides=[1000, 1], requires_grad=0, device=cpu) = onnx::Mul[onnx_name=\"/Mul\"](%data, %/Div_output_0), scope: __main__.M:: # /tmp/ipykernel_444618/1751548276.py:9:0\n",
      "  return (%output)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "from torch.onnx import OperatorExportTypes, utils\n",
    "\n",
    "\n",
    "class M(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x * (F.hardtanh(x + 3, 0., 6.) / 6.)\n",
    "\n",
    "# Build the module and put it in inference mode before exporting.\n",
    "model = M().eval()\n",
    "\n",
    "# Random float32 example input of shape (1, 1000) handed to the exporter.\n",
    "shape = (1, 1000)\n",
    "xx = torch.rand(shape, dtype=torch.float32, requires_grad=False)\n",
    "\n",
    "# Destination of the exported model: <root_dir>/<output_name>.onnx\n",
    "output_name = \"test\"\n",
    "onnx_path = f\"{root_dir}/{output_name}.onnx\"\n",
    "\n",
    "export_options = dict(\n",
    "    export_params=True,          # store the trained parameter weights inside the model file\n",
    "    opset_version=17,            # ONNX opset version to export with\n",
    "    do_constant_folding=True,    # run constant folding as an optimization\n",
    "    input_names=['data'],        # name of the model input\n",
    "    output_names=['output'],     # name of the model output\n",
    "    keep_initializers_as_inputs=True,\n",
    "    verbose=True,                # print the exported graph\n",
    "    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,\n",
    ")\n",
    "# For a variable-length batch axis, additionally pass\n",
    "# dynamic_axes={'data': {0: 'batch_size'}, 'output': {0: 'batch_size'}}.\n",
    "\n",
    "# Positional arguments: the torch model, its input (a tuple for several inputs),\n",
    "# and where to save the model (a file path or a file-like object).\n",
    "utils.export(model, xx, onnx_path, **export_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
